{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "import torchdyn\n",
    "import torch\n",
    "from torch.autograd import grad\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "from torchdyn.core import ODEProblem\n",
    "\n",
    "import torchdiffeq\n",
    "import time \n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.\n"
     ]
    }
   ],
   "source": [
    "f = nn.Sequential(nn.Linear(1, 32), nn.SELU(), nn.Linear(32, 32), nn.SELU(), nn.Linear(32, 1))\n",
    "prob = ODEProblem(f, solver='dopri5', sensitivity='adjoint', atol=1e-4, rtol=1e-4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Learning `T` from a target (3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 149,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.3283064365386963e-10, tensor([0.0000, 5.0000], grad_fn=<CatBackward>)\r"
     ]
    }
   ],
   "source": [
    "# torchdyn\n",
    "x = torch.randn(1, 1, requires_grad=True)\n",
    "t0 = torch.zeros(1)\n",
    "T = torch.ones(1).requires_grad_(True)\n",
    "opt = torch.optim.Adam((T,), lr=1e-2)\n",
    "\n",
    "for i in range(2000):\n",
    "    t_span = torch.cat([t0, T])\n",
    "    t_eval, traj = prob(x, t_span)\n",
    "    loss = ((t_span[-1:] - torch.tensor([5]))**2).mean()\n",
    "    print(f'{loss}, {t_span}', end='\\r')\n",
    "    loss.backward(); opt.step(); opt.zero_grad()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.3283064365386963e-10, tensor([0.0000, 5.0000], grad_fn=<CatBackward>)\r"
     ]
    }
   ],
   "source": [
    "# torchdiffeq\n",
    "# we have to wrap for torchdiffeq\n",
    "class VectorField(nn.Module):\n",
    "    def __init__(self, f):\n",
    "        super().__init__()\n",
    "        self.f = f\n",
    "    def forward(self, t, x):\n",
    "        return self.f(x)\n",
    "    \n",
    "sys = VectorField(f)\n",
    "x = torch.randn(1, 1, requires_grad=True)\n",
    "t0 = torch.zeros(1)\n",
    "T = torch.ones(1).requires_grad_(True)\n",
    "opt = torch.optim.Adam((T,), lr=1e-2)\n",
    "\n",
    "for i in range(2000):\n",
    "    t_span = torch.cat([t0, T])\n",
    "    traj = torchdiffeq.odeint_adjoint(sys, x, t_span, method='dopri5', atol=1e-4, rtol=1e-4)\n",
    "    loss = ((t_span[-1:] - torch.tensor([5]))**2).mean()\n",
    "    print(f'{loss}, {t_span}', end='\\r')\n",
    "    loss.backward(); opt.step(); opt.zero_grad()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Explicit loss on `T`, gradcheck"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.])"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t_span = torch.cat([t0, T])\n",
    "t_eval, traj = prob(x, t_span)\n",
    "l = ((t_span[-1:] - torch.tensor([5]))**2).mean()\n",
    "dldt_torchdyn = grad(l, T)[0]\n",
    "\n",
    "t_span = torch.cat([t0, T])\n",
    "traj = torchdiffeq.odeint_adjoint(sys, x, t_span, method='dopri5', atol=1e-4, rtol=1e-4)\n",
    "l = ((t_span[-1:] - torch.tensor([5]))**2).mean()\n",
    "dldt_torchdiffeq = grad(l, T)[0]\n",
    "\n",
    "dldt_torchdyn - dldt_torchdiffeq"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Explicit loss on `t0`, gradcheck"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.])"
      ]
     },
     "execution_count": 121,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t0 = torch.zeros(1).requires_grad_(True)\n",
    "T = torch.ones(1).requires_grad_(True)\n",
    "\n",
    "t_span = torch.cat([t0, T])\n",
    "t_eval, traj = prob(x, t_span)\n",
    "l = ((t_span[:1] - torch.tensor([5]))**2).mean()\n",
    "dldt_torchdyn = grad(l, t0)[0]\n",
    "\n",
    "t_span = torch.cat([t0, T])\n",
    "traj = torchdiffeq.odeint_adjoint(sys, x, t_span, method='dopri5', atol=1e-4, rtol=1e-4)\n",
    "l = ((t_span[:1] - torch.tensor([5]))**2).mean()\n",
    "dldt_torchdiffeq = grad(l, t0)[0]\n",
    "\n",
    "dldt_torchdyn - dldt_torchdiffeq"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Learning `xT` by stretching `T` (fixed vector field)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: vec field is always positive so we are sure to hit the target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.\n"
     ]
    }
   ],
   "source": [
    "f = nn.Sequential(nn.Linear(1, 32), nn.SELU(), nn.Linear(32, 1), nn.Softplus())\n",
    "prob = ODEProblem(f, solver='dopri5', sensitivity='adjoint', atol=1e-4, rtol=1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "L: 0.00, T: 3.30, xT: 2.00\r"
     ]
    }
   ],
   "source": [
    "# torchdyn\n",
    "\n",
    "x = torch.zeros(1, 1, requires_grad=True) + 0.5\n",
    "t0 = torch.zeros(1)\n",
    "T = torch.ones(1).requires_grad_(True)\n",
    "opt = torch.optim.Adam((T,), lr=1e-2)\n",
    "\n",
    "for i in range(1000):\n",
    "    t_span = torch.cat([t0, T])\n",
    "    t_eval, traj = prob(x, t_span)\n",
    "    loss = ((traj[-1] - torch.tensor([2]))**2).mean()\n",
    "    print(f'L: {loss.item():.2f}, T: {t_span[-1].item():.2f}, xT: {traj[-1].item():.2f}', end='\\r')\n",
    "    loss.backward(); opt.step(); opt.zero_grad()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "L: 0.00, T: 3.89, xT: 2.00\r"
     ]
    }
   ],
   "source": [
    "class VectorField(nn.Module):\n",
    "    def __init__(self, f):\n",
    "        super().__init__()\n",
    "        self.f = f\n",
    "    def forward(self, t, x):\n",
    "        return self.f(x)\n",
    "\n",
    "sys = VectorField(f)\n",
    "x = torch.zeros(1, 1, requires_grad=True) + 0.5\n",
    "t0 = torch.zeros(1)\n",
    "T = torch.ones(1).requires_grad_(True)\n",
    "opt = torch.optim.Adam((T,), lr=1e-2)\n",
    "\n",
    "for i in range(1000):\n",
    "    t_span = torch.cat([t0, T])\n",
    "    traj = torchdiffeq.odeint_adjoint(sys, x, t_span, method='dopri5', atol=1e-4, rtol=1e-4)\n",
    "    loss = ((traj[-1] - torch.tensor([2]))**2).mean()\n",
    "    print(f'L: {loss.item():.2f}, T: {t_span[-1].item():.2f}, xT: {traj[-1].item():.2f}', end='\\r')\n",
    "    loss.backward(); opt.step(); opt.zero_grad()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 146,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([-2.7140]), tensor([-2.1463]))"
      ]
     },
     "execution_count": 146,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.zeros(1, 1, requires_grad=True) + 0.5\n",
    "\n",
    "t_span = torch.cat([t0, T])\n",
    "t_eval, traj = prob(x, t_span)\n",
    "l = ((traj[-1] - torch.tensor([5]))**2).mean()\n",
    "dldt_torchdyn = grad(l, T)[0]\n",
    "\n",
    "t_span = torch.cat([t0, T])\n",
    "traj = torchdiffeq.odeint_adjoint(sys, x, t_span, method='dopri5', atol=1e-4, rtol=1e-4)\n",
    "l = ((traj[-1] - torch.tensor([5]))**2).mean()\n",
    "dldt_torchdiffeq = grad(l, T)[0]\n",
    "\n",
    "dldt_torchdyn, dldt_torchdiffeq"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torchdyn",
   "language": "python",
   "name": "torchdyn"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
